{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hmQcJgGH0TKw"
      },
      "source": [
        "# Including Node Features with Multilayer Perceptrons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JoN3GTmw0CTs",
        "outputId": "c4b7ebbc-79d6-44d3-b448-7006f2b586f4"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
        "\n",
        "torch.manual_seed(0)\n",
        "torch.cuda.manual_seed(0)\n",
        "torch.cuda.manual_seed_all(0)\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eMp_ertkDcJf",
        "outputId": "33e608a1-9e1e-4841-e835-20f5dff1b437"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Dataset: Cora()\n",
            "---------------\n",
            "Number of graphs: 1\n",
            "Number of nodes: 2708\n",
            "Number of features: 1433\n",
            "Number of classes: 7\n",
            "\n",
            "Graph:\n",
            "------\n",
            "Edges are directed: False\n",
            "Graph has isolated nodes: False\n",
            "Graph has loops: False\n"
          ]
        }
      ],
      "source": [
        "from torch_geometric.datasets import Planetoid\n",
        "\n",
        "# Import dataset from PyTorch Geometric\n",
        "dataset = Planetoid(root=\".\", name=\"Cora\")\n",
        "\n",
        "data = dataset[0]\n",
        "\n",
        "# Print information about the dataset\n",
        "print(f'Dataset: {dataset}')\n",
        "print('---------------')\n",
        "print(f'Number of graphs: {len(dataset)}')\n",
        "print(f'Number of nodes: {data.x.shape[0]}')\n",
        "print(f'Number of features: {dataset.num_features}')\n",
        "print(f'Number of classes: {dataset.num_classes}')\n",
        "\n",
        "# Print information about the graph\n",
        "print(f'\\nGraph:')\n",
        "print('------')\n",
        "print(f'Edges are directed: {data.is_directed()}')\n",
        "print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')\n",
        "print(f'Graph has loops: {data.has_self_loops()}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c-PceLVMoTkM",
        "outputId": "57ad9c83-9179-4e52-a039-a6ea21f82e06"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Dataset: FacebookPagePage()\n",
            "-----------------------\n",
            "Number of graphs: 1\n",
            "Number of nodes: 22470\n",
            "Number of features: 128\n",
            "Number of classes: 4\n",
            "\n",
            "Graph:\n",
            "------\n",
            "Edges are directed: False\n",
            "Graph has isolated nodes: False\n",
            "Graph has loops: True\n"
          ]
        }
      ],
      "source": [
        "from torch_geometric.datasets import FacebookPagePage\n",
        "\n",
        "# Import dataset from PyTorch Geometric\n",
        "dataset = FacebookPagePage(root=\".\")\n",
        "\n",
        "data = dataset[0]\n",
        "\n",
        "# Print information about the dataset\n",
        "print(f'Dataset: {dataset}')\n",
        "print('-----------------------')\n",
        "print(f'Number of graphs: {len(dataset)}')\n",
        "print(f'Number of nodes: {data.x.shape[0]}')\n",
        "print(f'Number of features: {dataset.num_features}')\n",
        "print(f'Number of classes: {dataset.num_classes}')\n",
        "\n",
        "# Print information about the graph\n",
        "print(f'\\nGraph:')\n",
        "print('------')\n",
        "print(f'Edges are directed: {data.is_directed()}')\n",
        "print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')\n",
        "print(f'Graph has loops: {data.has_self_loops()}')\n",
        "\n",
        "# Create masks\n",
        "data.train_mask = range(18000)\n",
        "data.val_mask = range(18001, 20000)\n",
        "data.test_mask = range(20001, 22470)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cora dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424
        },
        "id": "AtFTuxGObVbX",
        "outputId": "1f78cb2a-0db3-4b03-d9fd-d1e8321b954c"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>0</th>\n",
              "      <th>1</th>\n",
              "      <th>2</th>\n",
              "      <th>3</th>\n",
              "      <th>4</th>\n",
              "      <th>5</th>\n",
              "      <th>6</th>\n",
              "      <th>7</th>\n",
              "      <th>8</th>\n",
              "      <th>9</th>\n",
              "      <th>...</th>\n",
              "      <th>1424</th>\n",
              "      <th>1425</th>\n",
              "      <th>1426</th>\n",
              "      <th>1427</th>\n",
              "      <th>1428</th>\n",
              "      <th>1429</th>\n",
              "      <th>1430</th>\n",
              "      <th>1431</th>\n",
              "      <th>1432</th>\n",
              "      <th>label</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2703</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2704</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2705</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2706</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2707</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>...</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>2708 rows × 1434 columns</p>\n",
              "</div>"
            ],
            "text/plain": [
              "        0    1    2    3    4    5    6    7    8    9  ...  1424  1425  1426  \\\n",
              "0     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "1     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "2     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "3     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "4     0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "...   ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...   ...   ...   ...   \n",
              "2703  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  ...   0.0   1.0   0.0   \n",
              "2704  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "2705  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "2706  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "2707  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...   0.0   0.0   0.0   \n",
              "\n",
              "      1427  1428  1429  1430  1431  1432  label  \n",
              "0      0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "1      0.0   0.0   0.0   0.0   0.0   0.0      4  \n",
              "2      0.0   0.0   0.0   0.0   0.0   0.0      4  \n",
              "3      0.0   0.0   0.0   0.0   0.0   0.0      0  \n",
              "4      0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "...    ...   ...   ...   ...   ...   ...    ...  \n",
              "2703   0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "2704   0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "2705   0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "2706   0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "2707   0.0   0.0   0.0   0.0   0.0   0.0      3  \n",
              "\n",
              "[2708 rows x 1434 columns]"
            ]
          },
          "execution_count": 29,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import pandas as pd\n",
        "\n",
        "dataset = Planetoid(root=\".\", name=\"Cora\")\n",
        "data = dataset[0]\n",
        "\n",
        "df_x = pd.DataFrame(data.x.numpy())\n",
        "df_x['label'] = pd.DataFrame(data.y)\n",
        "df_x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W4iU8dx0DpFR",
        "outputId": "5e7ea7eb-873f-498c-b735-bc9873f8784d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "MLP(\n",
            "  (linear1): Linear(in_features=1433, out_features=16, bias=True)\n",
            "  (linear2): Linear(in_features=16, out_features=7, bias=True)\n",
            ")\n",
            "Epoch   0 | Train Loss: 1.959 | Train Acc: 14.29% | Val Loss: 2.00 | Val Acc: 12.40%\n",
            "Epoch  20 | Train Loss: 0.110 | Train Acc: 100.00% | Val Loss: 1.46 | Val Acc: 49.40%\n",
            "Epoch  40 | Train Loss: 0.014 | Train Acc: 100.00% | Val Loss: 1.44 | Val Acc: 51.00%\n",
            "Epoch  60 | Train Loss: 0.008 | Train Acc: 100.00% | Val Loss: 1.40 | Val Acc: 53.80%\n",
            "Epoch  80 | Train Loss: 0.008 | Train Acc: 100.00% | Val Loss: 1.37 | Val Acc: 55.40%\n",
            "Epoch 100 | Train Loss: 0.009 | Train Acc: 100.00% | Val Loss: 1.34 | Val Acc: 54.60%\n",
            "\n",
            "MLP test accuracy: 53.40%\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "torch.manual_seed(0)\n",
        "from torch.nn import Linear\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "def accuracy(y_pred, y_true):\n",
        "    \"\"\"Calculate accuracy.\"\"\"\n",
        "    return torch.sum(y_pred == y_true) / len(y_true)\n",
        "\n",
        "\n",
        "class MLP(torch.nn.Module):\n",
        "    \"\"\"Multilayer Perceptron\"\"\"\n",
        "    def __init__(self, dim_in, dim_h, dim_out):\n",
        "        super().__init__()\n",
        "        self.linear1 = Linear(dim_in, dim_h)\n",
        "        self.linear2 = Linear(dim_h, dim_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.linear1(x)\n",
        "        x = torch.relu(x)\n",
        "        x = self.linear2(x)\n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "    def fit(self, data, epochs):\n",
        "        criterion = torch.nn.CrossEntropyLoss()\n",
        "        optimizer = torch.optim.Adam(self.parameters(),\n",
        "                                          lr=0.01,\n",
        "                                          weight_decay=5e-4)\n",
        "\n",
        "        self.train()\n",
        "        for epoch in range(epochs+1):\n",
        "            optimizer.zero_grad()\n",
        "            out = self(data.x)\n",
        "            loss = criterion(out[data.train_mask], data.y[data.train_mask])\n",
        "            acc = accuracy(out[data.train_mask].argmax(dim=1),\n",
        "                          data.y[data.train_mask])\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            if(epoch % 20 == 0):\n",
        "                val_loss = criterion(out[data.val_mask], data.y[data.val_mask])\n",
        "                val_acc = accuracy(out[data.val_mask].argmax(dim=1),\n",
        "                                  data.y[data.val_mask])\n",
        "                print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc:'\n",
        "                      f' {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | '\n",
        "                      f'Val Acc: {val_acc*100:.2f}%')\n",
        "\n",
        "    @torch.no_grad()      \n",
        "    def test(self, data):\n",
        "        self.eval()\n",
        "        out = self(data.x)\n",
        "        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])\n",
        "        return acc\n",
        "\n",
        "# Create MLP model\n",
        "mlp = MLP(dataset.num_features, 16, dataset.num_classes)\n",
        "print(mlp)\n",
        "\n",
        "# Train\n",
        "mlp.fit(data, epochs=100)\n",
        "\n",
        "# Test\n",
        "acc = mlp.test(data)\n",
        "print(f'\\nMLP test accuracy: {acc*100:.2f}%')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "OTiqwwFXQlVA"
      },
      "outputs": [],
      "source": [
        "class VanillaGNNLayer(torch.nn.Module):\n",
        "    def __init__(self, dim_in, dim_out):\n",
        "        super().__init__()\n",
        "        self.linear = Linear(dim_in, dim_out, bias=False)\n",
        "\n",
        "    def forward(self, x, adjacency):\n",
        "        x = self.linear(x)\n",
        "        x = torch.sparse.mm(adjacency, x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u9eP_1qtROua",
        "outputId": "540c47fe-d910-4dd4-b310-b96eaf03ffc1"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[1., 0., 0.,  ..., 0., 0., 0.],\n",
              "        [0., 1., 1.,  ..., 0., 0., 0.],\n",
              "        [0., 1., 1.,  ..., 0., 0., 0.],\n",
              "        ...,\n",
              "        [0., 0., 0.,  ..., 1., 0., 0.],\n",
              "        [0., 0., 0.,  ..., 0., 1., 1.],\n",
              "        [0., 0., 0.,  ..., 0., 1., 1.]])"
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from torch_geometric.utils import to_dense_adj\n",
        "\n",
        "adjacency = to_dense_adj(data.edge_index)[0]\n",
        "adjacency += torch.eye(len(adjacency))\n",
        "adjacency"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FmHhedhcK9vi",
        "outputId": "286ea8be-b6eb-4ec0-fb4e-6b23c9e68479"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "VanillaGNN(\n",
            "  (gnn1): VanillaGNNLayer(\n",
            "    (linear): Linear(in_features=1433, out_features=16, bias=False)\n",
            "  )\n",
            "  (gnn2): VanillaGNNLayer(\n",
            "    (linear): Linear(in_features=16, out_features=7, bias=False)\n",
            "  )\n",
            ")\n",
            "Epoch   0 | Train Loss: 1.991 | Train Acc: 15.71% | Val Loss: 2.11 | Val Acc: 9.40%\n",
            "Epoch  20 | Train Loss: 0.065 | Train Acc: 99.29% | Val Loss: 1.47 | Val Acc: 76.80%\n",
            "Epoch  40 | Train Loss: 0.014 | Train Acc: 100.00% | Val Loss: 2.11 | Val Acc: 75.40%\n",
            "Epoch  60 | Train Loss: 0.007 | Train Acc: 100.00% | Val Loss: 2.22 | Val Acc: 75.40%\n",
            "Epoch  80 | Train Loss: 0.004 | Train Acc: 100.00% | Val Loss: 2.20 | Val Acc: 76.80%\n",
            "Epoch 100 | Train Loss: 0.003 | Train Acc: 100.00% | Val Loss: 2.19 | Val Acc: 77.00%\n",
            "\n",
            "GNN test accuracy: 76.60%\n"
          ]
        }
      ],
      "source": [
        "class VanillaGNN(torch.nn.Module):\n",
        "    \"\"\"Vanilla Graph Neural Network\"\"\"\n",
        "    def __init__(self, dim_in, dim_h, dim_out):\n",
        "        super().__init__()\n",
        "        self.gnn1 = VanillaGNNLayer(dim_in, dim_h)\n",
        "        self.gnn2 = VanillaGNNLayer(dim_h, dim_out)\n",
        "\n",
        "    def forward(self, x, adjacency):\n",
        "        h = self.gnn1(x, adjacency)\n",
        "        h = torch.relu(h)\n",
        "        h = self.gnn2(h, adjacency)\n",
        "        return F.log_softmax(h, dim=1)\n",
        "\n",
        "    def fit(self, data, epochs):\n",
        "        criterion = torch.nn.CrossEntropyLoss()\n",
        "        optimizer = torch.optim.Adam(self.parameters(),\n",
        "                                      lr=0.01,\n",
        "                                      weight_decay=5e-4)\n",
        "\n",
        "        self.train()\n",
        "        for epoch in range(epochs+1):\n",
        "            optimizer.zero_grad()\n",
        "            out = self(data.x, adjacency)\n",
        "            loss = criterion(out[data.train_mask], data.y[data.train_mask])\n",
        "            acc = accuracy(out[data.train_mask].argmax(dim=1),\n",
        "                          data.y[data.train_mask])\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            if(epoch % 20 == 0):\n",
        "                val_loss = criterion(out[data.val_mask], data.y[data.val_mask])\n",
        "                val_acc = accuracy(out[data.val_mask].argmax(dim=1),\n",
        "                                  data.y[data.val_mask])\n",
        "                print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc:'\n",
        "                      f' {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | '\n",
        "                      f'Val Acc: {val_acc*100:.2f}%')\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def test(self, data):\n",
        "        self.eval()\n",
        "        out = self(data.x, adjacency)\n",
        "        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])\n",
        "        return acc\n",
        "\n",
        "# Create the Vanilla GNN model\n",
        "gnn = VanillaGNN(dataset.num_features, 16, dataset.num_classes)\n",
        "print(gnn)\n",
        "\n",
        "# Train\n",
        "gnn.fit(data, epochs=100)\n",
        "\n",
        "# Test\n",
        "acc = gnn.test(data)\n",
        "print(f'\\nGNN test accuracy: {acc*100:.2f}%')"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Facebook Page-Page dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "MLP(\n",
            "  (linear1): Linear(in_features=128, out_features=16, bias=True)\n",
            "  (linear2): Linear(in_features=16, out_features=4, bias=True)\n",
            ")\n",
            "Epoch   0 | Train Loss: 1.401 | Train Acc: 28.11% | Val Loss: 1.40 | Val Acc: 28.91%\n",
            "Epoch  20 | Train Loss: 0.671 | Train Acc: 73.47% | Val Loss: 0.68 | Val Acc: 72.94%\n",
            "Epoch  40 | Train Loss: 0.579 | Train Acc: 76.95% | Val Loss: 0.61 | Val Acc: 74.89%\n",
            "Epoch  60 | Train Loss: 0.549 | Train Acc: 78.17% | Val Loss: 0.60 | Val Acc: 75.49%\n",
            "Epoch  80 | Train Loss: 0.533 | Train Acc: 78.74% | Val Loss: 0.60 | Val Acc: 75.44%\n",
            "Epoch 100 | Train Loss: 0.520 | Train Acc: 79.28% | Val Loss: 0.60 | Val Acc: 75.44%\n",
            "\n",
            "MLP test accuracy: 75.17%\n",
            "\n",
            "VanillaGNN(\n",
            "  (gnn1): VanillaGNNLayer(\n",
            "    (linear): Linear(in_features=128, out_features=16, bias=False)\n",
            "  )\n",
            "  (gnn2): VanillaGNNLayer(\n",
            "    (linear): Linear(in_features=16, out_features=4, bias=False)\n",
            "  )\n",
            ")\n",
            "Epoch   0 | Train Loss: 176.683 | Train Acc: 28.31% | Val Loss: 173.10 | Val Acc: 28.41%\n",
            "Epoch  20 | Train Loss: 6.675 | Train Acc: 79.69% | Val Loss: 4.49 | Val Acc: 80.19%\n",
            "Epoch  40 | Train Loss: 2.284 | Train Acc: 82.15% | Val Loss: 1.60 | Val Acc: 83.64%\n",
            "Epoch  60 | Train Loss: 1.233 | Train Acc: 83.91% | Val Loss: 1.06 | Val Acc: 84.34%\n",
            "Epoch  80 | Train Loss: 1.182 | Train Acc: 84.17% | Val Loss: 0.97 | Val Acc: 83.64%\n",
            "Epoch 100 | Train Loss: 2.855 | Train Acc: 83.74% | Val Loss: 1.57 | Val Acc: 83.64%\n",
            "\n",
            "GNN test accuracy: 83.92%\n"
          ]
        }
      ],
      "source": [
        "# Dataset: load FacebookPagePage and take its first graph\n",
        "dataset = FacebookPagePage(root=\".\")\n",
        "data = dataset[0]\n",
        "\n",
        "# Contiguous index splits. A range's stop is exclusive, so each split starts\n",
        "# exactly where the previous one stops; no node index is left out in between.\n",
        "data.train_mask = range(18000)\n",
        "data.val_mask = range(18000, 20000)\n",
        "data.test_mask = range(20000, 22470)\n",
        "\n",
        "# Dense adjacency matrix with the identity added on the diagonal (self-loops)\n",
        "# NOTE(review): `adjacency` is not passed to the models below; presumably\n",
        "# VanillaGNN reads it as a notebook-level global, so keep this name - confirm.\n",
        "adjacency = to_dense_adj(data.edge_index)[0]\n",
        "adjacency += torch.eye(len(adjacency))\n",
        "\n",
        "# MLP baseline\n",
        "mlp = MLP(dataset.num_features, 16, dataset.num_classes)\n",
        "print(mlp)\n",
        "mlp.fit(data, epochs=100)\n",
        "acc = mlp.test(data)\n",
        "print(f'\\nMLP test accuracy: {acc*100:.2f}%\\n')\n",
        "\n",
        "# Vanilla GNN\n",
        "gnn = VanillaGNN(dataset.num_features, 16, dataset.num_classes)\n",
        "print(gnn)\n",
        "gnn.fit(data, epochs=100)\n",
        "acc = gnn.test(data)\n",
        "print(f'\\nGNN test accuracy: {acc*100:.2f}%')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "book",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15 (default, Nov 24 2022, 14:38:14) [MSC v.1916 64 bit (AMD64)]"
    },
    "vscode": {
      "interpreter": {
        "hash": "3556630122da5213751af4465d61fcf5a52cd22515d400aee51118aaa1721248"
      }
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "327acf50ddb34eb38feab845c8da82e8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f6f770342e9e403e9ab5f73623a8efb6",
            "placeholder": "​",
            "style": "IPY_MODEL_43717136a906424fb303ee45c7242116",
            "value": "Computing transition probabilities: 100%"
          }
        },
        "43717136a906424fb303ee45c7242116": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51dd30f90e234e4aa17374bf2c2419f2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_327acf50ddb34eb38feab845c8da82e8",
              "IPY_MODEL_b0c83e7abab54f5792a37796cf195bc4",
              "IPY_MODEL_5983b3c6970644d183eb007ba105fcdb"
            ],
            "layout": "IPY_MODEL_c2aaf7ec538b49feaafbbe6962962454"
          }
        },
        "5983b3c6970644d183eb007ba105fcdb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_91af1f04a1dd4cadb4a231aa37d5af52",
            "placeholder": "​",
            "style": "IPY_MODEL_d9f1466ed11d479baf5878db3748cadb",
            "value": " 22470/22470 [01:34&lt;00:00, 203.42it/s]"
          }
        },
        "91af1f04a1dd4cadb4a231aa37d5af52": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "999663613ed7491b9a12992460611248": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b0c83e7abab54f5792a37796cf195bc4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_999663613ed7491b9a12992460611248",
            "max": 22470,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b4fe165a8fcf4812b18f7ad54b7a7016",
            "value": 22470
          }
        },
        "b4fe165a8fcf4812b18f7ad54b7a7016": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c2aaf7ec538b49feaafbbe6962962454": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d9f1466ed11d479baf5878db3748cadb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f6f770342e9e403e9ab5f73623a8efb6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
